""" Diffuser Implementation """
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import gym
import numpy as np

from sb3_jax.common.preprocessing import get_flattened_obs_dim, get_act_dim

from diffgro.environments.collect_dataset import get_skill_embed
from diffgro.diffuser.planner import DiffuserPlanner
from diffgro.diffgro.functions import guide_fn_dict, _loss_txt, _manual_loss_fn
from diffgro.utils import llm
from diffgro.utils import *


class Diffuser:
    """Step-wise controller that queries a diffusion planner on every call to ``predict``.

    Each ``predict`` call asks the planner for a ``horizon``-long plan and returns the
    action stored in the last ``act_dim`` columns of plan row ``self.h``. When the
    planner has ``dynamic_inpaint`` enabled, observations and actions seen so far are
    kept in rolling stacks and passed to the planner as a condition, together with a
    mask marking which entries are known.
    """

    guide_methods = ['blank', 'test', 'manual', 'llm']

    def __init__(
        self,
        env: gym.Env,
        planner: DiffuserPlanner,
        history: Optional[int] = None,   # history to stack; defaults to int(horizon / 2)
        guide: Optional[str] = None,     # guide method, one of `guide_methods`
        guide_pt: Optional[str] = None,  # prompt or path for llm guidance
        delta: float = 1.0,              # scale for guidance
        verbose: bool = False,
    ):
        """Wrap ``planner.policy`` for step-wise control of ``env``.

        Args:
            env: environment; ``_setup``/``predict`` read its spaces, ``env_name``,
                ``domain_name`` and (for 'metaworld_complex') ``full_task_list`` /
                ``success_count``.
            planner: object exposing ``.policy`` (the diffusion model) and
                ``.dynamic_inpaint``.
            history: number of stack rows kept when the inpainting stack rolls over;
                ``0`` resets the stacks after every step.
            guide: key into ``guide_fn_dict``; ``None`` disables guidance.
            guide_pt: guidance prompt (for 'test' it selects the loss from the guide dict).
            delta: guidance scale passed to the planner.
            verbose: stored, not otherwise used in this class.

        Raises:
            AssertionError: if ``guide`` is given but not in ``guide_methods``.
        """
        self.env = env
        self.planner = planner.policy
        self.dynamic_inpaint = planner.dynamic_inpaint
        print_b(f"Dynamic Inpaint: {self.dynamic_inpaint}")
        self.history = history
        # guidance
        if guide is not None:
            assert guide in Diffuser.guide_methods, f"Guide method {guide} should be in {Diffuser.guide_methods}"
        self.guide = guide
        self.guide_fn = guide_fn_dict[guide] if guide is not None else None
        self.guide_pt = guide_pt  # context
        self.context_info = None  # populated externally before _setup_guide() for 'manual'
        self.delta = delta
        # misc
        self.verbose = verbose

        self._setup()

    def _setup(self) -> None:
        """Derive dimensions, the default history length and task/skill embeddings."""
        self.obs_dim = get_flattened_obs_dim(self.env.observation_space)
        self.act_dim = get_act_dim(self.env.action_space)
        self.horizon = self.planner.horizon

        # history setting: default to half of the planning horizon
        if self.history is None:
            self.history = int(self.horizon / 2)
        print_b(f"[diffgro] History stack is set as {self.history}")

        # task embedding, reshaped to (1, -1) to carry a batch dimension
        self.task = get_skill_embed(None, self.env.env_name).reshape(1, -1)
        if self.env.domain_name == 'metaworld_complex':
            # one embedding per sub-task; predict() picks one via env.success_count
            self.skill = [get_skill_embed(None, task).reshape(1, -1) for task in self.env.full_task_list]

    def _setup_guide(self) -> None:
        """Build the guidance loss functions (``self.loss_fn``) for the chosen guide method.

        For ``guide == 'manual'`` this iterates ``self.context_info``, so it must be
        populated first. Each context is indexed as: ``[0]`` prompt text (stored in
        ``guide_pt``/``loss_pt``), ``[2]`` context type, ``[3]`` context target,
        ``[4]`` guidance scale.
        """
        self.n_guide_steps = 1
        if self.guide == 'test':
            # the same loss (selected by guide_pt) for every sub-task
            self.loss_fn = [self.guide_fn[self.guide_pt] for _ in range(self.env.task_num)]
        if self.guide == 'blank':  # no guidance, only for evaluating contexts
            self.n_guide_steps = 0
        if self.guide == 'manual':
            self.loss_fn, self.guide_pt, self.loss_pt = [], [], []
            for context in self.context_info:
                context_dict = {"context_type": context[2], "context_target": context[3]}
                self.loss_pt.append(context[0])
                self.guide_pt.append(context[0])
                loss_fn, _ = self.guide_fn(**context_dict)
                self.loss_fn.append(loss_fn)
                # NOTE: overwritten per context, so the last context's scale wins
                self.delta = context[4]
        if self.guide == 'llm':
            pass  # no loss functions are built here for 'llm'
        print_b(f"[diffuser] guidance function is '{self.guide}' and scale is '{self.delta}'")
        print_b(f"[diffuser] the guide prompt is '{self.guide_pt}'")

    def reset(self) -> None:
        """Reset the step counters and zero the observation/action inpainting stacks."""
        self.h, self.t = 0, 0
        self.obs_stack = np.zeros((1, self.horizon, self.obs_dim))
        self.act_stack = np.zeros((1, self.horizon, self.act_dim))

    def predict(self, obs: np.ndarray, deterministic: bool = True):
        """Plan from the current observation and return the action for this step.

        Args:
            obs: current observation without a batch dimension.
            deterministic: accepted for API compatibility; the planner is always
                queried with ``deterministic=True``.

        Returns:
            ``(action, None, {"guided": bool})`` where ``guided`` reports whether a
            guidance loss was applied for this step.
        """
        # add batch dimension
        obs = obs.reshape((-1,) + obs.shape)

        # 1. conditioning and masking
        cond, mask = self._build_condition(obs)

        # 2. inference: composite tasks also pass the current sub-task's skill embedding
        task, skill = self.task, None
        if self.env.domain_name == 'metaworld_complex':
            skill = self.skill[self.env.success_count]

        # 'blank' builds no loss functions in _setup_guide, so it must plan unguided
        if self.guide_fn is None or self.guide == 'blank':
            plan, act = self.predict_without_guide(cond, task, skill, mask)
        else:
            plan, act = self.predict_with_guide(cond, task, skill, mask)

        self.t += 1

        # 3. stacking
        if self.dynamic_inpaint:
            self._update_stacks(act)

        act = np.array(act.copy())
        return act, None, {"guided": self.guided}

    def _build_condition(self, obs):
        """Return ``(cond, mask)`` for the planner; ``mask`` is None without dynamic inpainting."""
        if not self.dynamic_inpaint:
            return obs, None
        # write the current observation into the rolling stack at row h
        self.obs_stack[0][self.h] = obs[0]
        cond = np.concatenate((self.obs_stack, self.act_stack), axis=-1)
        # known entries: observations in rows 0..h, actions in rows 0..h-1
        # (the action for row h is what the planner must produce)
        mask_obs = np.zeros((1, self.horizon, self.obs_dim))
        mask_obs[:, :self.h + 1] = 1
        mask_act = np.zeros((1, self.horizon, self.act_dim))
        mask_act[:, :self.h] = 1
        mask = np.concatenate((mask_obs, mask_act), axis=-1)
        return cond, mask

    def _update_stacks(self, act) -> None:
        """Record the returned action and advance (or roll over) the inpainting stacks."""
        # store the action for row h so later steps inpaint it; the mask marks past
        # actions as known, so leaving this row at zero would condition on zeros
        self.act_stack[0][self.h] = act
        self.h += 1
        if self.history == 0:
            self.reset()
        elif self.h == (self.horizon - 1):
            # keep the last `history` filled rows (row horizon-1 is never filled) and
            # zero-pad the rest; h = history-1 means the next step overwrites the most
            # recent kept row, i.e. that last state is discarded
            self.obs_stack = np.concatenate(
                (self.obs_stack[:, -self.history - 1:-1, :], np.zeros((1, self.horizon - self.history, self.obs_dim))),
                axis=1,
            )
            self.act_stack = np.concatenate(
                (self.act_stack[:, -self.history - 1:-1, :], np.zeros((1, self.horizon - self.history, self.act_dim))),
                axis=1,
            )
            self.h = self.history - 1

    def predict_without_guide(self, cond, task, skill, mask):
        """Query the planner without guidance.

        Returns:
            ``(plan, act)`` where ``act`` is the last ``act_dim`` columns of plan row
            ``self.h`` for the first batch element.
        """
        self.guided = False
        plan, _ = self.planner._predict(cond, task, skill, mask, delta=None, guide_fn=None, deterministic=True)
        act = plan[:, self.h, -self.act_dim:][0]
        return plan, act

    def predict_with_guide(self, cond, task, skill, mask):
        """Query the planner with the guidance loss for the current (sub-)task.

        For 'metaworld_complex' the loss is chosen by ``env.success_count``. When no
        loss exists for that index, the step is planned unguided and
        ``self.guided`` is set to False.

        Returns:
            ``(plan, act)`` as in :meth:`predict_without_guide`.
        """
        self.guided = True
        if self.env.domain_name == 'metaworld_complex':
            try:
                loss_fn = self.loss_fn[self.env.success_count]
            except IndexError:
                loss_fn = None
                self.guided = False
        else:
            loss_fn = self.loss_fn[0]
        plan, _ = self.planner._predict(cond, task, skill, mask, delta=self.delta, guide_fn=loss_fn, deterministic=True)
        act = plan[:, self.h, -self.act_dim:][0]
        return plan, act
